import haiku as hk
import jax
import jax.numpy as jnp

from diffgro.common.models.utils import act2fn, init_he_normal
from diffgro.common.models.helpers import Transformer, _layer_norm
from diffgro.common.models.diffusion import UNetDiffusion


################ Skill Predictor ################


class SkillPredictor(hk.Module):
    """Predict a continuous skill vector from a language instruction and observations.

    The language embedding is prepended to the time-embedded observation
    sequence. The combined sequence is layer-normalised and run through a
    ``Transformer``. The output at the final position (the last observation
    token) is linearly projected to ``skill_dim``.
    """

    def __init__(
        self,
        skill_dim: int,
        emb_dim: int,
        n_heads: int,
        n_layers: int,
        max_seq_len: int = 500,
    ):
        """
        Args
        skill_dim: size of the predicted skill vector.
        emb_dim: width of the token embeddings fed to the transformer.
        n_heads: number of attention heads in the transformer.
        n_layers: number of transformer layers.
        max_seq_len: number of rows in the learned time embedding, i.e. the
            longest observation sequence that can be embedded. Defaults to
            500 (the previously hard-coded size), so existing parameters keep
            the same shape.
        """
        super().__init__()
        self.skill_dim = skill_dim
        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len

    def __call__(self, obs, lang):
        """
        Input
        obs: (B, seq_len, obs_dim)
        lang: (B, 1, lang_dim)

        Output
        skill_pred = (B, skill_dim)

        Raises
        ValueError: if seq_len is larger than max_seq_len.
        """

        seq_len = obs.shape[1]
        if seq_len > self.max_seq_len:
            # Slicing time_emb past its length silently truncates, and the
            # addition below would then fail with an opaque shape error.
            raise ValueError(
                f"SkillPredictor supports sequences of at most "
                f"{self.max_seq_len} steps, got seq_len={seq_len}."
            )

        # Observation & Language embedding (creation order kept stable so
        # Haiku module names and init RNG consumption do not change).
        obs_emb = hk.Linear(self.emb_dim, with_bias=False)(obs)
        lang_emb = hk.Linear(self.emb_dim, with_bias=False)(lang)

        # Learned absolute time embedding, one row per observation step.
        embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
        time_emb = hk.get_parameter(
            "time_emb", [self.max_seq_len, self.emb_dim], init=embed_init
        )

        # Prepare input: add time embeddings to observations only, then
        # prepend the language token.
        obs_emb = obs_emb + time_emb[None, :seq_len]
        x = jnp.concatenate([lang_emb, obs_emb], axis=1)  # (B, seq_len + 1, emb_dim)
        x = _layer_norm(x)

        h = Transformer(
            num_heads=self.n_heads, num_layers=self.n_layers, dropout_rate=0.1
        )(x)
        # The language token sits at index 0, so the last position is the
        # final observation token.
        x = h[:, -1]  # (B, emb_dim)
        skill_pred = hk.Linear(self.skill_dim, with_bias=False)(x)  # (B, skill_dim)

        return skill_pred


class VectorQuantizer(hk.Module):
    """Quantize skill vectors against a learned codebook.

    Thin wrapper around ``hk.nets.VectorQuantizerEMA`` that returns only the
    VQ loss and the quantized vectors from its result dict.

    NOTE: VectorQuantizerEMA keeps its codebook and EMA statistics in Haiku
    state, so the enclosing function must be transformed with
    ``hk.transform_with_state``.
    """

    def __init__(
        self,
        code_dim: int,
        n_codes: int,
        decay: float = 0.99,
        commitment_cost: float = 0.25,
    ):
        """
        Args
        code_dim: dimensionality of each codebook vector; must match the last
            axis of the input.
        n_codes: number of codebook entries.
        decay: EMA decay used for the codebook updates.
        commitment_cost: weight of the commitment term in the VQ loss.
            Defaults to 0.25 (the previously hard-coded value).
        """
        super().__init__()

        self.code_dim = code_dim
        self.n_codes = n_codes
        self.decay = decay
        self.commitment_cost = commitment_cost

    def __call__(self, skill_pred, is_training=False):
        """
        Input
        skill_pred: (B, code_dim)
        is_training: forwarded to VectorQuantizerEMA; when False the EMA
            codebook statistics are not updated.

        Output
        (vq_loss, skill_quant)
        vq_loss: loss value reported by VectorQuantizerEMA (its "loss" entry)
        skill_quant: (B, code_dim) quantized vectors (its "quantize" entry)
        """

        result = hk.nets.VectorQuantizerEMA(
            num_embeddings=self.n_codes,
            embedding_dim=self.code_dim,
            decay=self.decay,
            commitment_cost=self.commitment_cost,
        )(skill_pred, is_training=is_training)

        vq_loss, skill_quant = result["loss"], result["quantize"]
        return vq_loss, skill_quant
